{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example we will see how to train a SOM to create a map of hand-written digits using the <a href=\"https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits\">UCI ML hand-written digits dataset</a>.\n",
    "\n",
    "First, we'll 1) load the data using the scikit-learn wrapper, 2) scale the data, and 3) train the SOM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from minisom import MiniSom\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from sklearn import datasets\n",
    "from sklearn.preprocessing import scale\n",
    "\n",
    "# load the digits dataset from scikit-learn\n",
    "digits = datasets.load_digits(n_class=10)\n",
    "data = digits.data  # matrix where each row is a vector that represents a digit\n",
    "data = scale(data)  # standardize each feature to zero mean and unit variance\n",
    "num = digits.target  # num[i] is the digit represented by data[i]\n",
    "\n",
    "som = MiniSom(30, 30, 64, sigma=4,\n",
    "              learning_rate=0.5, neighborhood_function='triangle')\n",
    "som.pca_weights_init(data)\n",
    "som.train(data, 5000, random_order=True, verbose=True)  # random training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that each input vector for the SOM represents an entire image, obtained by reshaping the original 8-by-8 image into a vector of 64 elements. The input images are grayscale.\n",
    "\n",
    "We can now place each digit on the map represented by the SOM:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 8))\n",
    "wmap = {}\n",
    "im = 0\n",
    "for x, t in zip(data, num):  # scatterplot\n",
    "    w = som.winner(x)\n",
    "    wmap[w] = im\n",
    "    plt.text(w[0]+.5, w[1]+.5, str(t),\n",
    "             color=plt.cm.rainbow(t / 10.), fontdict={'weight': 'bold', 'size': 11})\n",
    "    im = im + 1\n",
    "plt.axis([0, som.get_weights().shape[0], 0,  som.get_weights().shape[1]])\n",
    "plt.savefig('resulting_images/som_digts.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10), facecolor='white')\n",
    "cnt = 0\n",
    "n_x, n_y = som.get_weights().shape[:2]  # map size, 30-by-30 here\n",
    "for j in reversed(range(n_y)):  # images mosaic\n",
    "    for i in range(n_x):\n",
    "        plt.subplot(n_y, n_x, cnt+1, frameon=False, xticks=[], yticks=[])\n",
    "        if (i, j) in wmap:\n",
    "            plt.imshow(digits.images[wmap[(i, j)]],\n",
    "                       cmap='Greys', interpolation='nearest')\n",
    "        else:\n",
    "            plt.imshow(np.zeros((8, 8)), cmap='Greys')\n",
    "        cnt = cnt + 1\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('resulting_images/som_digts_imgs.png')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
